import fitz
import os
import numpy as np
import json
from openai import OpenAI
from datetime import datetime

# Set the API key (placeholder shown; supply your own key or export
# OPENAI_API_KEY in your shell before running)
os.environ["OPENAI_API_KEY"] = "your-api-key-here"
# Initialize the OpenAI client against the SiliconFlow endpoint
client = OpenAI(
    base_url="https://api.siliconflow.cn/v1",
    api_key=os.getenv("OPENAI_API_KEY")  # Read the key from the environment
)

def extract_text_from_pdf(pdf_path):
    """Extract and concatenate the text of every page in a PDF."""
    mypdf = fitz.open(pdf_path)
    all_text = ""  # Accumulates the extracted text
    for page_num in range(mypdf.page_count):
        page = mypdf[page_num]  # Get the page
        text = page.get_text("text")  # Extract plain text from the page
        all_text += text  # Append the page text
    mypdf.close()  # Release the file handle
    return all_text

def chunk_text(text, n, overlap):
    """Split text into chunks of n characters, with `overlap` characters shared between neighbors."""
    if overlap >= n:
        raise ValueError("overlap must be smaller than the chunk size n")
    chunks = []
    for i in range(0, len(text), n - overlap):
        chunks.append(text[i:i + n])
    return chunks  # Return the list of text chunks
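# Worked example: the window advances by n - overlap characters, so consecutive
# chunks share an `overlap`-character boundary and the final chunk may be short:
#   chunk_text("abcdefghij", 6, 2) -> ["abcdef", "efghij", "ij"]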

class SimpleVectorStore:
    """A minimal in-memory vector store ranked by cosine similarity."""

    def __init__(self):
        self.vectors = []  # List to store embedding vectors
        self.texts = []  # List to store original text chunks
        self.metadata = []  # List to store metadata for each text chunk

    def add_item(self, text, embedding, metadata=None):
        self.vectors.append(np.array(embedding))  # Convert and store the embedding
        self.texts.append(text)  # Store the original text
        self.metadata.append(metadata or {})  # Store metadata (empty dict if None)

    def similarity_search(self, query_embedding, k=5, filter_func=None):
        if not self.vectors:
            return []  # Return empty list if vector store is empty
        query_vector = np.array(query_embedding)
        similarities = []
        for i, vector in enumerate(self.vectors):
            if filter_func and not filter_func(self.metadata[i]):
                continue
            similarity = np.dot(query_vector, vector) / (np.linalg.norm(query_vector) * np.linalg.norm(vector))
            similarities.append((i, similarity))
        similarities.sort(key=lambda x: x[1], reverse=True)
        results = []
        for i in range(min(k, len(similarities))):
            idx, score = similarities[i]
            results.append({
                "text": self.texts[idx],
                "metadata": self.metadata[idx],
                "similarity": score,
                # Use pre-existing relevance score from metadata if available, otherwise use similarity
                "relevance_score": self.metadata[idx].get("relevance_score", score)
            })
        return results
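# Minimal usage sketch with toy 3-d vectors (illustrative, not executed here):
#   store = SimpleVectorStore()
#   store.add_item("about cats", [1.0, 0.0, 0.0], {"index": 0})
#   store.add_item("about dogs", [0.0, 1.0, 0.0], {"index": 1})
#   store.similarity_search([0.9, 0.1, 0.0], k=1)[0]["text"]  # -> "about cats"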

def create_embeddings(text, model="BAAI/bge-m3"):
    input_text = text if isinstance(text, list) else [text]
    response = client.embeddings.create(
        model=model,
        input=input_text
    )
    if isinstance(text, str):
        return response.data[0].embedding
    return [item.embedding for item in response.data]
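# Return shape mirrors the input (embedding values illustrative):
#   create_embeddings("hello")      -> [0.012, -0.034, ...]   a single vector
#   create_embeddings(["a", "b"])   -> [[...], [...]]         one vector per text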

def get_user_feedback(query, response, relevance, quality, comments=""):
    return {
        "query": query,
        "response": response,
        "relevance": int(relevance),
        "quality": int(quality),
        "comments": comments,
        "timestamp": datetime.now().isoformat()
    }

def store_feedback(feedback, feedback_file="feedback_data.json"):
    with open(feedback_file, "a", encoding="utf-8") as f:
        json.dump(feedback, f)
        f.write("\n")
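# Each record is appended as one JSON object per line (JSONL). Illustrative line:
#   {"query": "...", "response": "...", "relevance": 4, "quality": 5,
#    "comments": "", "timestamp": "2024-01-01T12:00:00.000000"}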

def load_feedback_data(feedback_file="feedback_data.json"):
    feedback_data = []
    try:
        with open(feedback_file, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    feedback_data.append(json.loads(line.strip()))
    except FileNotFoundError:
        print("No feedback data file found. Starting with empty feedback.")
    return feedback_data

def process_document(pdf_path, chunk_size=1000, chunk_overlap=200):
    print("Extracting text from PDF...")
    extracted_text = extract_text_from_pdf(pdf_path)
    print("Chunking text...")
    chunks = chunk_text(extracted_text, chunk_size, chunk_overlap)
    print(f"Created {len(chunks)} text chunks")
    print("Creating embeddings for chunks...")
    chunk_embeddings = create_embeddings(chunks)
    store = SimpleVectorStore()
    for i, (chunk, embedding) in enumerate(zip(chunks, chunk_embeddings)):
        store.add_item(
            text=chunk,
            embedding=embedding,
            metadata={
                "index": i,  # Position in original document
                "source": pdf_path,  # Source document path
                "relevance_score": 1.0,  # Initial relevance score (will be updated with feedback)
                "feedback_count": 0  # Counter for feedback received on this chunk
            }
        )
    print(f"Added {len(chunks)} chunks to the vector store")
    return chunks, store

def assess_feedback_relevance(query, doc_text, feedback):
    system_prompt = """You are an AI system that determines whether past feedback is relevant to the current query and document.
    Answer with ONLY 'yes' or 'no'. Your job is strictly to determine relevance, not to provide explanations."""
    user_prompt = f"""
        Current query: {query}
        Past query that received feedback: {feedback['query']}
        Document content: {doc_text[:500]}... [truncated]
        Past response that received feedback: {feedback['response'][:500]}... [truncated]
        Is this past feedback relevant to the current query and document? (yes/no)
        """
    response = client.chat.completions.create(
        model="Qwen/Qwen2-1.5B-Instruct",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        temperature=0  # Use temperature=0 for consistent, deterministic responses
    )
    answer = response.choices[0].message.content.strip().lower()
    return 'yes' in answer  # Return True if the answer contains 'yes'

def adjust_relevance_scores(query, results, feedback_data):
    if not feedback_data:
        return results
    print("Adjusting relevance scores based on feedback history...")
    for i, result in enumerate(results):
        document_text = result["text"]
        relevant_feedback = []
        for feedback in feedback_data:
            is_relevant = assess_feedback_relevance(query, document_text, feedback)
            if is_relevant:
                relevant_feedback.append(feedback)
        if relevant_feedback:
            avg_relevance = sum(f['relevance'] for f in relevant_feedback) / len(relevant_feedback)
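            # The modifier below maps avg_relevance (1-5 scale) into [0.7, 1.5]:
            # 0.5 + 1/5 = 0.7 down-weights, 0.5 + 5/5 = 1.5 boosts, and an
            # average of 2.5 is neutral. E.g. similarity 0.80 with avg 4.0
            # becomes 0.80 * 1.3 = 1.04 (scores can exceed 1.0 after boosting).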
            modifier = 0.5 + (avg_relevance / 5.0)
            original_score = result["similarity"]
            adjusted_score = original_score * modifier
            result["original_similarity"] = original_score  # Preserve the original score
            result["similarity"] = adjusted_score  # Update the primary score
            result["relevance_score"] = adjusted_score  # Update the relevance score
            result["feedback_applied"] = True  # Flag that feedback was applied
            result["feedback_count"] = len(relevant_feedback)  # Number of feedback entries used
            print(
                f"  Document {i + 1}: Adjusted score from {original_score:.4f} to {adjusted_score:.4f} based on {len(relevant_feedback)} feedback(s)")
    results.sort(key=lambda x: x["similarity"], reverse=True)
    return results

def fine_tune_index(current_store, chunks, feedback_data):
    print("Fine-tuning index with high-quality feedback...")
    good_feedback = [f for f in feedback_data if f['relevance'] >= 4 and f['quality'] >= 4]
    if not good_feedback:
        print("No high-quality feedback found for fine-tuning.")
        return current_store
    new_store = SimpleVectorStore()
    for i in range(len(current_store.texts)):
        new_store.add_item(
            text=current_store.texts[i],
            embedding=current_store.vectors[i],
            metadata=current_store.metadata[i].copy()
        )
    for feedback in good_feedback:
        enhanced_text = f"Question: {feedback['query']}\nAnswer: {feedback['response']}"
        embedding = create_embeddings(enhanced_text)
        new_store.add_item(
            text=enhanced_text,
            embedding=embedding,
            metadata={
                "type": "feedback_enhanced",  # Mark as derived from feedback
                "query": feedback["query"],  # Store original query for reference
                "relevance_score": 1.2,  # Boost initial relevance to prioritize these items
                "feedback_count": 1,  # Track feedback incorporation
                "original_feedback": feedback  # Preserve complete feedback record
            }
        )
        print(f"Added enhanced content from feedback: {feedback['query'][:50]}...")
    print(f"Fine-tuned index now has {len(new_store.texts)} items (original: {len(chunks)})")
    return new_store
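# Note: ranking in SimpleVectorStore.similarity_search is still by cosine
# similarity; the 1.2 relevance_score stored above is surfaced through each
# result's "relevance_score" field rather than used for ordering.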

def generate_response(query, context, model="Qwen/Qwen2-1.5B-Instruct"):
    system_prompt = """You are a helpful AI assistant. Answer the user's question based only on the provided context. 
    If you cannot find the answer in the context, state that you don't have enough information."""
    user_prompt = f"""
        Context:
        {context}
        Question: {query}
        Please provide a comprehensive answer based only on the context above."""
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        temperature=0  # Use temperature=0 for consistent, deterministic responses
    )
    return response.choices[0].message.content

def rag_with_feedback_loop(query, vector_store, feedback_data, k=5, model="Qwen/Qwen2-1.5B-Instruct"):
    print(f"\n=== Processing query with feedback-enhanced RAG ===")
    print(f"Query: {query}")
    query_embedding = create_embeddings(query)
    results = vector_store.similarity_search(query_embedding, k=k)
    adjusted_results = adjust_relevance_scores(query, results, feedback_data)
    retrieved_texts = [result["text"] for result in adjusted_results]
    context = "\n\n---\n\n".join(retrieved_texts)
    print("Generating response...")
    response = generate_response(query, context, model)
    result = {
        "query": query,
        "retrieved_documents": adjusted_results,
        "response": response
    }
    print("\n=== Response ===")
    print(response)
    return result

def full_rag_workflow(pdf_path, query, feedback_data=None, feedback_file="feedback_data.json", fine_tune=False):
    if feedback_data is None:
        feedback_data = load_feedback_data(feedback_file)
        print(f"Loaded {len(feedback_data)} feedback entries from {feedback_file}")
    chunks, vector_store = process_document(pdf_path)
    if fine_tune and feedback_data:
        vector_store = fine_tune_index(vector_store, chunks, feedback_data)
    result = rag_with_feedback_loop(query, vector_store, feedback_data)
    print("\n=== Would you like to provide feedback on this response? ===")
    print("Rate relevance (1-5, with 5 being most relevant):")
    relevance = input()
    print("Rate quality (1-5, with 5 being highest quality):")
    quality = input()
    print("Any comments? (optional, press Enter to skip)")
    comments = input()
    feedback = get_user_feedback(
        query=query,
        response=result["response"],
        relevance=int(relevance),
        quality=int(quality),
        comments=comments
    )
    store_feedback(feedback, feedback_file)
    print("Feedback recorded. Thank you!")
    return result

def evaluate_feedback_loop(pdf_path, test_queries, reference_answers=None):
    print("=== Evaluating Feedback Loop Impact ===")
    temp_feedback_file = "temp_evaluation_feedback.json"
    feedback_data = []
    print("\n=== ROUND 1: NO FEEDBACK ===")
    round1_results = []
    chunks, vector_store = process_document(pdf_path)  # Build the index once, not per query
    for i, query in enumerate(test_queries):
        print(f"\nQuery {i + 1}: {query}")
        result = rag_with_feedback_loop(query, vector_store, [])
        round1_results.append(result)
        if reference_answers and i < len(reference_answers):
            similarity_to_ref = calculate_similarity(result["response"], reference_answers[i])
            relevance = max(1, min(5, int(similarity_to_ref * 5)))
            quality = max(1, min(5, int(similarity_to_ref * 5)))
            feedback = get_user_feedback(
                query=query,
                response=result["response"],
                relevance=relevance,
                quality=quality,
                comments=f"Synthetic feedback based on reference similarity: {similarity_to_ref:.2f}"
            )
            feedback_data.append(feedback)
            store_feedback(feedback, temp_feedback_file)
    print("\n=== ROUND 2: WITH FEEDBACK ===")
    round2_results = []
    chunks, vector_store = process_document(pdf_path)
    vector_store = fine_tune_index(vector_store, chunks, feedback_data)

    for i, query in enumerate(test_queries):
        print(f"\nQuery {i + 1}: {query}")
        result = rag_with_feedback_loop(query, vector_store, feedback_data)
        round2_results.append(result)
    comparison = compare_results(test_queries, round1_results, round2_results, reference_answers)
    if os.path.exists(temp_feedback_file):
        os.remove(temp_feedback_file)
    return {
        "round1_results": round1_results,
        "round2_results": round2_results,
        "comparison": comparison
    }

def calculate_similarity(text1, text2):
    embedding1 = create_embeddings(text1)
    embedding2 = create_embeddings(text2)
    vec1 = np.array(embedding1)
    vec2 = np.array(embedding2)
    similarity = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
    return similarity
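# Worked cosine-similarity example with toy 2-d vectors (illustrative):
#   vec1 = [1, 0], vec2 = [1, 1]
#   dot = 1, |vec1| = 1, |vec2| = sqrt(2)  ->  similarity = 1/sqrt(2) ≈ 0.707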

def compare_results(queries, round1_results, round2_results, reference_answers=None):
    print("\n=== COMPARING RESULTS ===")
    system_prompt = """You are an expert evaluator of RAG systems. Compare responses from two versions:
        1. Standard RAG: No feedback used
        2. Feedback-enhanced RAG: Uses a feedback loop to improve retrieval

        Analyze which version provides better responses in terms of:
        - Relevance to the query
        - Accuracy of information
        - Completeness
        - Clarity and conciseness
    """
    comparisons = []
    for i, (query, r1, r2) in enumerate(zip(queries, round1_results, round2_results)):
        comparison_prompt = f"""
            Query: {query}
            Standard RAG Response:
            {r1["response"]}
            Feedback-enhanced RAG Response:
            {r2["response"]}
            """
        if reference_answers and i < len(reference_answers):
            comparison_prompt += f"""
            Reference Answer:
            {reference_answers[i]}
            """
        comparison_prompt += """
        Compare these responses and explain which one is better and why.
        Focus specifically on how the feedback loop has (or hasn't) improved the response quality.
        """
        response = client.chat.completions.create(
            model="Qwen/Qwen2-1.5B-Instruct",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": comparison_prompt}
            ],
            temperature=0
        )
        comparisons.append({
            "query": query,
            "analysis": response.choices[0].message.content
        })
        print(f"\nQuery {i + 1}: {query}")
        print(f"Analysis: {response.choices[0].message.content[:200]}...")
    return comparisons

pdf_path = "data/AI_Information.pdf"
test_queries = [
    "What is a neural network and how does it function?",
    #################################################################################
    ### Commented out queries to reduce the number of queries for testing purposes ###
    # "Describe the process and applications of reinforcement learning.",
    # "What are the main applications of natural language processing in today's technology?",
    # "Explain the impact of overfitting in machine learning models and how it can be mitigated."
]

reference_answers = [
    "A neural network is a series of algorithms that attempt to recognize underlying relationships in a set of data "
    "through a process that mimics the way the human brain operates. It consists of layers of nodes, with each node "
    "representing a neuron. Neural networks function by adjusting the weights of connections between nodes based on "
    "the error of the output compared to the expected result.",
    ############################################################################################
    #### Commented out reference answers to reduce the number of queries for testing purposes ###

    #     "Reinforcement learning is a type of machine learning where an agent learns to make decisions by performing actions in an environment to maximize cumulative reward. It involves exploration, exploitation, and learning from the consequences of actions. Applications include robotics, game playing, and autonomous vehicles.",
    #     "The main applications of natural language processing in today's technology include machine translation, sentiment analysis, chatbots, information retrieval, text summarization, and speech recognition. NLP enables machines to understand and generate human language, facilitating human-computer interaction.",
    #     "Overfitting in machine learning models occurs when a model learns the training data too well, capturing noise and outliers. This results in poor generalization to new data, as the model performs well on training data but poorly on unseen data. Mitigation techniques include cross-validation, regularization, pruning, and using more training data."
]
evaluation_results = evaluate_feedback_loop(
    pdf_path=pdf_path,
    test_queries=test_queries,
    reference_answers=reference_answers
)
comparisons = evaluation_results['comparison']
print("\n=== FEEDBACK IMPACT ANALYSIS ===\n")
for i, comparison in enumerate(comparisons):
    print(f"Query {i+1}: {comparison['query']}")
    print(f"\nAnalysis of feedback impact:")
    print(comparison['analysis'])
    print("\n" + "-"*50 + "\n")
round_count = len(evaluation_results) - 1  # dict holds roundN_results entries plus "comparison"
round_responses = [evaluation_results[f'round{round_num}_results'] for round_num in range(1, round_count + 1)]
response_lengths = [[len(r["response"]) for r in round_results] for round_results in round_responses]
print("\nResponse length comparison (proxy for completeness):")
avg_lengths = [sum(lengths) / len(lengths) for lengths in response_lengths]
for round_num, avg_len in enumerate(avg_lengths, start=1):
    print(f"Round {round_num}: {avg_len:.1f} chars")
if len(avg_lengths) > 1:
    changes = [(avg_lengths[i] - avg_lengths[i-1]) / avg_lengths[i-1] * 100 for i in range(1, len(avg_lengths))]
    for round_num, change in enumerate(changes, start=2):
        print(f"Change from Round {round_num-1} to Round {round_num}: {change:.1f}%")

